from typing import NamedTuple, Optional

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.agents.multi_agent import MultiAgentLearner, MultiAgent
from centralized_verification.shields.shield import Shield
from centralized_verification.utils import TrainingProgress


class TrainingLimits(NamedTuple):
    """Optional budgets for a training run, plus helpers that test progress against them.

    Every field defaults to ``None``, which means "no limit of this kind".
    ``max_episode_len`` is only stored here; neither method below reads it.
    """

    max_total_steps: Optional[int] = None
    max_episode_len: Optional[int] = None
    max_num_episodes: Optional[int] = None

    def should_stop_training(self, train_progress: TrainingProgress):
        """Return True once a configured step or episode budget has been reached.

        :param train_progress: object exposing ``global_step_count`` and
            ``global_episode_count``.
        :return: True if ``global_step_count >= max_total_steps`` or
            ``global_episode_count >= max_num_episodes`` (each test is skipped
            when its limit is ``None``); False otherwise, including when no
            limit is set at all.
        """
        if self.max_total_steps is not None and train_progress.global_step_count >= self.max_total_steps:
            return True

        if self.max_num_episodes is not None and train_progress.global_episode_count >= self.max_num_episodes:
            return True

        return False

    def is_at_logging_interval(self, train_progress: TrainingProgress, num_log_intervals: int, end_of_episode: bool):
        """Tell whether the current progress falls on a logging boundary.

        Each configured budget is divided into ``num_log_intervals`` equal
        intervals, and a boundary is any count that is a multiple of the
        interval length (a count of 0 therefore counts as a boundary).

        :param train_progress: object exposing ``global_step_count`` and
            ``global_episode_count``.
        :param num_log_intervals: how many intervals to divide each budget into.
            A non-positive value yields no boundaries at all.
        :param end_of_episode: whether an episode has just finished. Only the
            episode-based check depends on it; the step-based check does not.
        :return: True if either the step-based or the episode-based check hits
            a boundary, False otherwise (including when no budget is set).
        """
        if num_log_intervals <= 0:
            # Nothing sensible to divide the budget into; also avoids dividing by zero.
            return False

        if self.max_total_steps is not None:
            # Clamp to 1: a budget smaller than the number of intervals would
            # otherwise give an interval of 0 and make the modulo below raise.
            log_every_num_steps = max(1, int(self.max_total_steps // num_log_intervals))
            if train_progress.global_step_count % log_every_num_steps == 0:
                return True

        if self.max_num_episodes is not None and end_of_episode:
            log_every_num_eps = max(1, int(self.max_num_episodes // num_log_intervals))
            if train_progress.global_episode_count % log_every_num_eps == 0:
                return True

        return False


class TestingLimits(NamedTuple):
    """Limits for a testing run; both fields are required (no defaults)."""
    # Presumably the maximum number of steps allowed per test episode — confirm at the call site.
    max_episode_len: int
    # Presumably how many test episodes to run — confirm at the call site.
    num_episodes: int


class Configuration(NamedTuple):
    """Immutable record grouping the objects and settings of one run that uses a learner.

    All fields are required. How each one is consumed is decided by code
    outside this module; the notes below are assumptions to be confirmed.
    """
    shield: Shield
    env: MultiAgentSafetyEnv
    learner: MultiAgentLearner
    # Presumably an identifier for the run (e.g. for logs/checkpoints) — confirm at the call site.
    run_name: str
    # Step/episode budgets; see TrainingLimits.
    limits: TrainingLimits
    # NOTE(review): looks like this is what gets passed as ``num_log_intervals`` to
    # TrainingLimits.is_at_logging_interval — confirm at the call site.
    num_log_entries: int
    # Presumably the number of checkpoints to save over the run — confirm at the call site.
    num_checkpoints: int


class TestConfiguration(NamedTuple):
    """Immutable record grouping the objects and settings of one run that uses a fixed agent.

    Parallels ``Configuration`` but holds a ``MultiAgent`` instead of a
    ``MultiAgentLearner``, uses ``TestingLimits``, and has no logging or
    checkpoint counts. All fields are required.
    """
    shield: Shield
    env: MultiAgentSafetyEnv
    agent: MultiAgent
    # Presumably an identifier for the run — confirm at the call site.
    run_name: str
    # Episode count and per-episode length; see TestingLimits.
    limits: TestingLimits
